{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0c419316",
   "metadata": {
    "id": "nasty-intention"
   },
   "source": [
    "# Colab Demo"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18421181",
   "metadata": {
    "id": "previous-bacon"
   },
   "source": [
    "## Dependencies and Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02c37e1b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-16T15:42:28.140308Z",
     "start_time": "2023-10-16T15:42:25.626586Z"
    },
    "cellView": "form",
    "id": "complicated-receiver"
   },
   "outputs": [],
   "source": [
    "#@title Install dependencies\n",
    "\n",
    "!pip install -q torchaudio omegaconf\n",
    "\n",
    "import torch\n",
    "from omegaconf import OmegaConf\n",
    "from IPython.display import Audio, display\n",
    "\n",
    "\n",
    "torch.hub.download_url_to_file('https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',\n",
    "                               'latest_silero_models.yml',\n",
    "                               progress=False)\n",
    "models = OmegaConf.load('latest_silero_models.yml')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb699d3a",
   "metadata": {
    "id": "7oQde2ZNE-ok"
   },
   "source": [
    "## List models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e9772a6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-16T15:42:33.392937Z",
     "start_time": "2023-10-16T15:42:33.387682Z"
    },
    "id": "pacific-injury"
   },
   "outputs": [],
   "source": [
    "# see latest avaiable models\n",
    "available_models = models.denoise_models.models\n",
    "print(f'Available models {available_models}')\n",
    "\n",
    "for am in available_models:\n",
    "    _models = list(models.denoise_models.get(am).keys())\n",
    "    print(f'Available models for {am}: {_models}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02f1f2de",
   "metadata": {},
   "source": [
    "## SnS Latest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4953abc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-16T15:42:45.005072Z",
     "start_time": "2023-10-16T15:42:44.567115Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "name = 'small_slow' # 'large_fast', 'small_fast'\n",
    "device = torch.device('cpu')\n",
    "\n",
    "model, samples, utils = torch.hub.load(\n",
    "  repo_or_dir='snakers4/silero-models',\n",
    "  model='silero_denoise',\n",
    "  name=name,\n",
    "  device=device)\n",
    "\n",
    "(read_audio, save_audio, denoise) = utils\n",
    "\n",
    "model.to(device)  # gpu or cpu"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74f3193f",
   "metadata": {},
   "source": [
    "### ex 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8c199a2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-16T15:43:44.104289Z",
     "start_time": "2023-10-16T15:43:42.036666Z"
    }
   },
   "outputs": [],
   "source": [
    "i = 0\n",
    "torch.hub.download_url_to_file(\n",
    "  samples[i],\n",
    "  dst=f'sample{i}.wav',\n",
    "  progress=True\n",
    ")\n",
    "audio_path = f'sample{i}.wav'\n",
    "audio = read_audio(audio_path).to(device)\n",
    "print('Original: ')\n",
    "display(Audio(audio.cpu(), rate=24000))\n",
    "output = model(audio)\n",
    "save_audio(f'result{i}.wav', output.squeeze(1).cpu())\n",
    "print('Saved: ')\n",
    "display(Audio(f'result{i}.wav'))\n",
    "print('Output: ')\n",
    "display(Audio(output.squeeze(1).cpu(), rate=48000))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bcc9671",
   "metadata": {},
   "source": [
    "### ex 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b611ef65",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-16T15:44:40.294426Z",
     "start_time": "2023-10-16T15:44:35.029118Z"
    }
   },
   "outputs": [],
   "source": [
    "i = 1\n",
    "torch.hub.download_url_to_file(\n",
    "  samples[i],\n",
    "  dst=f'sample{i}.wav',\n",
    "  progress=True\n",
    ")\n",
    "print('Original: ')\n",
    "display(Audio(f'sample{i}.wav'))\n",
    "output, sr = denoise(model, f'sample{i}.wav', f'result{i}.wav', device='cpu')\n",
    "print('Saved: ')\n",
    "display(Audio(f'result{i}.wav'))\n",
    "print('Output: ')\n",
    "display(Audio(output.cpu(), rate=sr))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "656fd24c",
   "metadata": {
    "id": "Pwg_AYTCmufA"
   },
   "source": [
    "# Minimal Example to Run Locally "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69910a67",
   "metadata": {
    "id": "Phw8EAVBJGW8"
   },
   "source": [
    "## Dependencies and Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f177905f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-16T15:45:36.402889Z",
     "start_time": "2023-10-16T15:45:34.997892Z"
    },
    "cellView": "form",
    "id": "pLTZK0O7JPHT"
   },
   "outputs": [],
   "source": [
    "#@title Install dependencies\n",
    "\n",
    "!pip install -q torch==2.0.0 torchaudio"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e99bee10",
   "metadata": {},
   "source": [
    "## SnS Latest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85354c88",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-16T15:53:09.177664Z",
     "start_time": "2023-10-16T15:53:02.430939Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torchaudio\n",
    "\n",
    "\n",
    "def read_audio(\n",
    "    path: str,\n",
    "    sampling_rate: int = 24000\n",
    "):\n",
    "    wav, sr = torchaudio.load(path)\n",
    "    if wav.size(0) > 1:\n",
    "        wav = wav.mean(dim=0, keepdim=True)\n",
    "    if sr != sampling_rate:\n",
    "        transform = torchaudio.transforms.Resample(\n",
    "            orig_freq=sr,\n",
    "            new_freq=sampling_rate\n",
    "        )\n",
    "        wav = transform(wav)\n",
    "        sr = sampling_rate\n",
    "    assert sr == sampling_rate\n",
    "    return wav * 0.95\n",
    "\n",
    "\n",
    "device = torch.device('cpu')\n",
    "torch.set_num_threads(4)\n",
    "local_file = 'model.pt'\n",
    "\n",
    "if not os.path.isfile(local_file):\n",
    "    torch.hub.download_url_to_file('https://models.silero.ai/denoise_models/sns_latest.jit',\n",
    "                                   local_file)  \n",
    "\n",
    "model = torch.jit.load(local_file)\n",
    "torch._C._jit_set_profiling_mode(False) \n",
    "torch.set_grad_enabled(False)\n",
    "model.to(device)\n",
    "\n",
    "torch.hub.download_url_to_file('https://models.silero.ai/denoise_models/sample1.wav',\n",
    "                                   dst=f'sample1.wav', progress=True)\n",
    "audio_path = 'sample1.wav'\n",
    "a = read_audio(audio_path)\n",
    "a = a.to(device)\n",
    "out = model(a)\n",
    "print('Original: ')\n",
    "display(Audio(a.cpu(), rate=24000))\n",
    "print('Output: ')\n",
    "display(Audio(out.cpu().squeeze(1).detach().numpy(), rate=48000))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d45eff8c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "name": "examples_tts.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
